#!/usr/bin/env python3
"""
generate_flip_counts
====================

Generate a per‑link array of tick‑flip counts for an L×L periodic lattice.
This module generalises the Volume‑4 flip‑count simulator to arbitrary
lattice size.  It implements the tick‑flip operator algebra directly
using the definitions from the private ``ar‑operator‑core`` repository.

The primary entry point is :func:`generate_flip_counts`, which returns a
NumPy array of length ``2*L*L``.  A simple CLI interface is provided via
the ``main`` function; run ``python generate_flip_counts.py --help`` for
usage.
"""

import argparse
import os
from dataclasses import dataclass
import numpy as np


@dataclass(frozen=True)
class TickState:
    """Frozen pairing of a tick distribution with its context depth ``N``.

    Construction enforces that ``distribution`` is a one-dimensional
    ``numpy.ndarray`` with exactly ``2*N + 1`` entries.  ``frozen=True``
    only prevents rebinding the attributes; the array itself is neither
    copied nor made read-only.
    """

    distribution: np.ndarray  # 1-D array of length 2*N + 1
    N: int  # context depth (half-range of the distribution)

    def __post_init__(self):
        arr = self.distribution
        # Type first, then rank, then length; each check relies on the previous one.
        if not isinstance(arr, np.ndarray):
            raise TypeError("distribution must be a numpy.ndarray")
        if arr.ndim != 1:
            raise ValueError("distribution must be a 1D array")
        want = 2 * self.N + 1
        got = arr.size
        if got != want:
            raise ValueError(
                f"distribution length must be 2*N+1 ({want}), got {got}"
            )


def renewal(state: TickState) -> TickState:
    """Apply the Renewal (F) operator: shift mass one slot toward index 0.

    Every entry at index ``i >= 2`` moves to ``i - 1``; the entries at
    indices 0 and 1 merge into index 0, which therefore absorbs mass.  The
    last slot is left empty and total mass is conserved.  A single-element
    distribution (``N == 0``) is returned as an unchanged copy.

    :param state: input tick state (not modified)
    :returns: new :class:`TickState` with the same ``N``
    """
    dist = state.distribution
    size = dist.size
    new = np.zeros_like(dist)
    if size == 1:
        # Nothing to shift: keep the lone entry (as S does).  Zeroing the
        # last slot here would wipe index 0 and destroy all the mass.
        new[0] = dist[0]
        return TickState(new, state.N)
    new[0] = dist[0] + dist[1]
    new[1 : size - 1] = dist[2:]
    # new[size - 1] is already 0 from zeros_like: the vacated last slot.
    return TickState(new, state.N)


def F(state: TickState) -> TickState:
    """Short-name alias for :func:`renewal` (the Renewal operator F)."""
    renewed = renewal(state)
    return renewed


def S(state: TickState) -> TickState:
    """Apply the Sink operator: shift mass one slot toward the last index.

    Entries at indices ``0 .. size-3`` move up by one; the last two entries
    merge into the final slot, which therefore absorbs mass.  Index 0 is
    left empty and total mass is conserved.  A single-element distribution
    is returned as an unchanged copy.

    :param state: input tick state (not modified)
    :returns: new :class:`TickState` with the same ``N``
    """
    src = state.distribution
    n = src.size
    out = np.zeros_like(src)
    if n == 1:
        out[0] = src[0]
        return TickState(out, state.N)
    out[1 : n - 1] = src[: n - 2]
    out[n - 1] = src[n - 1] + src[n - 2]
    return TickState(out, state.N)


def X(state: TickState) -> TickState:
    """Apply the Distinction operator: reverse the distribution.

    The reversed array is an independent copy, so the result never aliases
    the input state's array.
    """
    flipped = np.flip(state.distribution).copy()
    return TickState(flipped, state.N)


def C(state: TickState) -> TickState:
    """Apply the Sync operator: symmetrise the distribution about its centre.

    Each entry becomes the mean of itself and its mirror entry, so the
    result is palindromic (index ``i`` equals index ``-1 - i``).
    """
    forward = state.distribution
    mirrored = np.flip(forward)
    symmetric = (forward + mirrored) * 0.5
    return TickState(symmetric, state.N)


def Phi(state: TickState) -> TickState:
    """Apply the Frame coupling operator Φ = C ∘ X (reverse, then sync).

    C already averages a distribution with its own reverse, so reversing
    first does not change the outcome: Φ(s) has the same distribution as C(s).
    """
    reversed_state = X(state)
    return C(reversed_state)


def build_default_lattice(size: int, boundary: str = "periodic") -> np.ndarray:
    """
    Enumerate the forward links of a 2D square lattice of side ``size``.

    Every site ``(x, y)`` with ``0 <= x, y < size`` contributes one link per
    direction ``mu`` (``mu=0``: +x, ``mu=1``: +y).  Links are ordered by x,
    then y, then mu.

    :param size: side length of the lattice
    :param boundary: ``"periodic"`` keeps every link (the neighbour wraps
        around), giving ``2*size*size`` links.  Any other value is treated
        as open boundaries: links whose neighbour would fall outside the
        lattice are dropped, so fewer links are returned.
    :returns: object array built from the ``((x, y), mu)`` entries.  NumPy
        stores these ragged tuples as shape ``(n_links, 2)``: column 0 holds
        the ``(x, y)`` tuple, column 1 the direction.  Each row unpacks as
        ``(x, y), mu = row``.
    """
    periodic = boundary == "periodic"
    directions = ((1, 0), (0, 1))  # mu=0: +x, mu=1: +y
    links = []
    for x in range(size):
        for y in range(size):
            for mu, (dx, dy) in enumerate(directions):
                # Periodic neighbours always exist (they wrap modulo size),
                # so only open boundaries can drop a link.  dx, dy >= 0, so
                # only the upper bound can be violated.
                if not periodic and not (x + dx < size and y + dy < size):
                    continue
                links.append(((x, y), mu))
    return np.array(links, dtype=object)


def count_flips_on_link(link: tuple, sim_params: dict) -> int:
    """
    Run a randomised tick-flip operator walk for one link and count flips.

    The walk starts from a unit spike at the centre index ``N`` of a
    length ``2*N + 1`` distribution.  Each of ``steps_per_link`` steps
    applies all five operators (F, S, X, C, Phi) once, in an order drawn
    from a per-link RNG.  After every single operator application, the
    value at the link's watch index is compared with its value just before.
    If they differ beyond ``np.isclose`` default tolerances, that counts as
    one flip.

    :param link: ``((x, y), mu)`` entry identifying the link; the watch
        index is ``(x + y + mu) % (2*N + 1)``
    :param sim_params: dict with keys ``"N"`` and ``"steps_per_link"``, and
        optionally ``"seed"`` (absent or ``None`` means unseeded)
    :returns: total number of flips observed at the watch index
    """
    depth = sim_params["N"]
    n_steps = sim_params["steps_per_link"]
    # default_rng(None) pulls fresh OS entropy; an int seed is reproducible.
    rng = np.random.default_rng(sim_params.get("seed", None))

    width = 2 * depth + 1
    spike = np.zeros(width)
    spike[depth] = 1.0
    current = TickState(spike, depth)

    (x, y), mu = link
    watch = (x + y + mu) % width

    # The same list is shuffled in place on every step, so each step's order
    # is a permutation of the previous step's order.  Rebuilding the list each
    # step would change the seeded operator sequences.
    operators = [F, S, X, C, Phi]
    flips = 0
    for _step in range(n_steps):
        rng.shuffle(operators)
        for apply_op in operators:
            nxt = apply_op(current)
            before = current.distribution[watch]
            after = nxt.distribution[watch]
            if not np.isclose(after, before):
                flips += 1
            current = nxt
    return flips


def generate_flip_counts(
    lattice_size: int,
    seed: int | None,
    context_depth: int,
    steps_per_link: int,
) -> np.ndarray:
    """
    Generate flip counts for every link of a periodic
    ``lattice_size``×``lattice_size`` lattice via tick-flip operator walks.

    Each link runs its own walk (:func:`count_flips_on_link`) with its own
    RNG.  When ``seed`` is given, the link at position ``idx`` (lattice
    order) is seeded with ``seed + idx * 9973``, so runs are reproducible.
    When ``seed`` is None, every link draws fresh entropy.

    :param lattice_size: side length of the square lattice
    :param seed: non-negative random seed for reproducibility, or None
    :param context_depth: context depth N (distribution length is 2*N+1)
    :param steps_per_link: number of operator sequences per link
    :returns: integer array of length ``2*lattice_size*lattice_size``
        containing flip counts, in lattice order
    :raises ValueError: if ``seed`` is negative (NumPy seeds must be
        non-negative; previously this surfaced as an opaque NumPy error)
    """
    if seed is not None and seed < 0:
        raise ValueError(f"seed must be a non-negative integer or None, got {seed}")
    lattice = build_default_lattice(lattice_size, boundary="periodic")
    flip_counts = np.zeros(len(lattice), dtype=int)
    # Parameters shared by all links; only the seed varies per link.
    base_params = {
        "N": context_depth,
        "steps_per_link": steps_per_link,
    }
    # NOTE: per-link seeds form an arithmetic progression, so runs with
    # different base seeds share RNG streams (seed s + 9973 at link idx uses
    # the same seed as seed s at link idx + 1).  This is kept so existing
    # seeded outputs stay reproducible.  Use
    # np.random.SeedSequence(seed).spawn(...) if separate runs must be
    # statistically independent.
    for idx, link in enumerate(lattice):
        link_seed = None if seed is None else seed + idx * 9973
        sim_params = base_params | {"seed": link_seed}
        flip_counts[idx] = count_flips_on_link(link, sim_params)
    return flip_counts


def main() -> None:
    """Parse CLI arguments, generate flip counts, and save them as a ``.npy`` file."""
    parser = argparse.ArgumentParser(
        description="Generate tick‑flip counts for an L×L lattice",
    )
    parser.add_argument(
        "--lattice-size",
        "-L",
        type=int,
        default=4,
        help="Lattice side length (e.g. 6)",
    )
    parser.add_argument(
        "--seed",
        "-s",
        type=int,
        default=None,
        help="Random seed for reproducibility",
    )
    parser.add_argument(
        "--context-depth",
        "-N",
        type=int,
        default=2,
        help="Context depth (half‑range of distribution)",
    )
    parser.add_argument(
        "--steps-per-link",
        "-t",
        type=int,
        default=1000,
        help="Number of operator sequences per link",
    )
    parser.add_argument(
        "--output",
        "-o",
        default="data/flip_counts.npy",
        help="Output .npy file path",
    )
    args = parser.parse_args()

    # Reject values that would otherwise fail deep inside the simulation or
    # silently produce an empty/meaningless array.  parser.error exits with
    # a usage message.
    if args.lattice_size < 1:
        parser.error("--lattice-size must be a positive integer")
    if args.context_depth < 0:
        parser.error("--context-depth must be non-negative")
    if args.steps_per_link < 0:
        parser.error("--steps-per-link must be non-negative")
    if args.seed is not None and args.seed < 0:
        parser.error("--seed must be non-negative")

    counts = generate_flip_counts(
        lattice_size=args.lattice_size,
        seed=args.seed,
        context_depth=args.context_depth,
        steps_per_link=args.steps_per_link,
    )

    # Sanity-check the shape before touching the filesystem.
    expected_len = 2 * args.lattice_size * args.lattice_size
    if counts.size != expected_len:
        raise RuntimeError(
            f"Generated flip counts have length {counts.size}, expected {expected_len}"
        )

    out_path = args.output
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    np.save(out_path, counts)
    # np.save appends ".npy" when the path lacks it; report the file actually written.
    if not out_path.endswith(".npy"):
        out_path += ".npy"
    print(
        f"Saved flip counts for {args.lattice_size}×{args.lattice_size} lattice to {out_path} (length {len(counts)})"
    )


if __name__ == "__main__":
    main()